{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "087d1891",
   "metadata": {},
   "source": [
    "# Model validation using offline batch inference\n",
    "\n",
    "<div align=\"left\">\n",
    "<a target=\"_blank\" href=\"https://console.anyscale.com/\"><img src=\"https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf\"></a>&nbsp;\n",
    "<a href=\"https://github.com/anyscale/e2e-xgboost\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d\"></a>&nbsp;\n",
    "</div>\n",
    "\n",
    "This tutorial executes a batch inference workload that connects the following heterogeneous stages:\n",
    "- Distributed read from cloud storage\n",
    "- Distributed preprocessing\n",
    "- Parallel batch inference\n",
    "- Distributed aggregation of summary metrics\n",
    "\n",
    "Note that this tutorial fetches the pre-trained model artifacts from the [Distributed training of an XGBoost model](./01-Distributed_Training.ipynb) tutorial.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/batch_inference.png\" width=800>\n",
    "\n",
    "The preceding figure illustrates how Ray Data can concurrently process different chunks of data at various stages of the pipeline. This parallel execution maximizes resource utilization and throughput.\n",
    "\n",
    "Note that this diagram is a simplification for various reasons:\n",
    "\n",
    "* Backpressure mechanisms may throttle upstream operators to prevent overwhelming downstream stages\n",
    "* Dynamic repartitioning often occurs as data moves through the pipeline, changing block counts and sizes\n",
    "* Available resources change as the cluster autoscales\n",
    "* System failures may disrupt the clean sequential flow shown in the diagram"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59218309",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert\"> <b> Ray Data streaming execution</b> \n",
    "\n",
    "❌ **Traditional batch execution**, that is, non-streaming execution such as Spark without pipelining or SageMaker Batch Transform:\n",
    "- Reads the entire dataset into memory or a persistent intermediate format\n",
    "- Only then starts applying transformations like `.map`, `.filter`, and so on\n",
    "- Higher memory pressure and startup latency\n",
    "\n",
    "✅ **Streaming execution** with Ray Data:\n",
    "- Starts processing chunks (\"blocks\") as they're loaded, without waiting for the entire dataset to load\n",
    "- Reduces memory footprint, which lowers the risk of out-of-memory errors, and shortens the time to first output\n",
    "- Increases resource utilization by reducing idle time\n",
    "- Enables online-style inference pipelines with minimal latency\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/streaming.gif\" width=700>\n",
    "\n",
    "**Note**: Ray Data isn't a real-time stream processing engine like Flink or Kafka Streams. Instead, it's batch processing with streaming execution, which is especially useful for iterative ML workloads, ETL pipelines, and preprocessing before training or inference. Ray Data typically achieves a [**2-17x throughput improvement**](https://www.anyscale.com/blog/offline-batch-inference-comparing-ray-apache-spark-and-sagemaker#-results-of-throughput-from-experiments) over solutions like Spark and SageMaker Batch Transform.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f5493d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53176de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable importing from dist_xgboost package.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "sys.path.append(os.path.abspath(\"..\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b35e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable Ray Train v2.\n",
    "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n",
    "# Now it's safe to import from ray.train."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd3d72bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "import dist_xgboost\n",
    "\n",
    "# Initialize Ray with the dist_xgboost package.\n",
    "ray.init(runtime_env={\"py_modules\": [dist_xgboost]})\n",
    "\n",
    "# Configure Ray Data logging.\n",
    "ray.data.DataContext.get_current().enable_progress_bars = False\n",
    "ray.data.DataContext.get_current().print_on_execution_start = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f8fa9ff",
   "metadata": {},
   "source": [
    "## Validating the model using Ray Data\n",
    "The previous tutorial, [Distributed training of an XGBoost model](./01-Distributed_Training.ipynb), trained an XGBoost model and stored it in the MLflow artifact storage. In this step, use that model to make predictions on the hold-out test set.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7816f41",
   "metadata": {},
   "source": [
    "### Data ingestion\n",
    "\n",
    "Load the test dataset using the same procedure as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "871a792a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-16 21:14:42,328\tINFO worker.py:1660 -- Connecting to existing Ray cluster at address: 10.0.23.200:6379...\n",
      "2025-04-16 21:14:42,338\tINFO worker.py:1843 -- Connected to Ray cluster. View the dashboard at \u001b[1m\u001b[32mhttps://session-1kebpylz8tcjd34p4sv2h1f9tg.i.anyscaleuserdata.com \u001b[39m\u001b[22m\n",
      "2025-04-16 21:14:42,343\tINFO packaging.py:575 -- Creating a file package for local module '/home/ray/default/e2e-xgboost/dist_xgboost'.\n",
      "2025-04-16 21:14:42,346\tINFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_fbb1935a37eb6438.zip' (0.02MiB) to Ray cluster...\n",
      "2025-04-16 21:14:42,347\tINFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_fbb1935a37eb6438.zip'.\n",
      "2025-04-16 21:14:42,347\tINFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_534f9f38a44d4c15f21856eec72c3c338db77a6b.zip' (0.08MiB) to Ray cluster...\n",
      "2025-04-16 21:14:42,348\tINFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_534f9f38a44d4c15f21856eec72c3c338db77a6b.zip'.\n",
      "2025-04-16 21:14:44,609\tINFO dataset.py:2809 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'mean radius': 15.7,\n",
       "  'mean texture': 20.31,\n",
       "  'mean perimeter': 101.2,\n",
       "  'mean area': 766.6,\n",
       "  'mean smoothness': 0.09597,\n",
       "  'mean compactness': 0.08799,\n",
       "  'mean concavity': 0.06593,\n",
       "  'mean concave points': 0.05189,\n",
       "  'mean symmetry': 0.1618,\n",
       "  'mean fractal dimension': 0.05549,\n",
       "  'radius error': 0.3699,\n",
       "  'texture error': 1.15,\n",
       "  'perimeter error': 2.406,\n",
       "  'area error': 40.98,\n",
       "  'smoothness error': 0.004626,\n",
       "  'compactness error': 0.02263,\n",
       "  'concavity error': 0.01954,\n",
       "  'concave points error': 0.009767,\n",
       "  'symmetry error': 0.01547,\n",
       "  'fractal dimension error': 0.00243,\n",
       "  'worst radius': 20.11,\n",
       "  'worst texture': 32.82,\n",
       "  'worst perimeter': 129.3,\n",
       "  'worst area': 1269.0,\n",
       "  'worst smoothness': 0.1414,\n",
       "  'worst compactness': 0.3547,\n",
       "  'worst concavity': 0.2902,\n",
       "  'worst concave points': 0.1541,\n",
       "  'worst symmetry': 0.3437,\n",
       "  'worst fractal dimension': 0.08631,\n",
       "  'target': 0}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ray.data import Dataset\n",
    "\n",
    "\n",
    "def prepare_data() -> tuple[Dataset, Dataset, Dataset]:\n",
    "    \"\"\"Load and split the dataset into train, validation, and test sets.\"\"\"\n",
    "    dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n",
    "    seed = 42\n",
    "    train_dataset, rest = dataset.train_test_split(test_size=0.3, shuffle=True, seed=seed)\n",
    "    # 15% for validation, 15% for testing.\n",
    "    valid_dataset, test_dataset = rest.train_test_split(test_size=0.5, shuffle=True, seed=seed)\n",
    "    return train_dataset, valid_dataset, test_dataset\n",
    "\n",
    "\n",
    "_, _, test_dataset = prepare_data()\n",
    "# Use `take()` to trigger execution because Ray Data uses lazy evaluation mode.\n",
    "test_dataset.take(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30093e79",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-block alert\"> <b>💡 Ray Data best practices</b>\n",
    "\n",
    "- **Use `materialize()` during development**: The `materialize()` method executes and stores the dataset in Ray's shared memory object store. This behavior creates a checkpoint so future operations can start from this point instead of rerunning all operations from scratch.\n",
    "- **Choose appropriate shuffling strategies**: Ray Data provides various [shuffling strategies](https://docs.ray.io/en/latest/data/shuffling-data.html) including local shuffles and per-epoch shuffles. You need to shuffle this dataset because the original data groups items by class.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd7eb41a",
   "metadata": {},
   "source": [
    "Next, transform the input data the same way you did during training. Fetch the preprocessor from the artifact registry:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "136097f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from dist_xgboost.constants import preprocessor_fname\n",
    "from dist_xgboost.data import get_best_model_from_registry\n",
    "\n",
    "best_run, best_artifacts_dir = get_best_model_from_registry()\n",
    "\n",
    "with open(os.path.join(best_artifacts_dir, preprocessor_fname), \"rb\") as f:\n",
    "    preprocessor = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d851437",
   "metadata": {},
   "source": [
    "Now define the transformation step in the Ray Data pipeline. Instead of processing each item individually with `.map()`, use Ray Data's [`map_batches`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html) method to process entire batches at once, which is much more efficient:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca4d87eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'mean radius': 0.4202879281965173, 'mean texture': 0.2278148207774012, 'mean perimeter': 0.35489846800755104, 'mean area': 0.29117590184541364, 'mean smoothness': -0.039721410464208406, 'mean compactness': -0.30321758777095337, 'mean concavity': -0.2973304995033593, 'mean concave points': 0.05629912285695481, 'mean symmetry': -0.6923528276633714, 'mean fractal dimension': -1.0159489469979848, 'radius error': -0.1244372811541358, 'texture error': -0.1073358496664629, 'perimeter error': -0.2253381140213174, 'area error': 0.001804996358367429, 'smoothness error': -0.8087740189276656, 'compactness error': -0.1437977993323648, 'concavity error': -0.3926326901399853, 'concave points error': -0.34157926393517024, 'symmetry error': -0.5862955941365042, 'fractal dimension error': -0.496152478599194, 'worst radius': 0.7695260874215265, 'worst texture': 1.1287525414418031, 'worst perimeter': 0.6310282171135395, 'worst area': 0.6506421499178707, 'worst smoothness': 0.39052158034274154, 'worst compactness': 0.6735246675401986, 'worst concavity': 0.06668871795848759, 'worst concave points': 0.5859784499947507, 'worst symmetry': 0.8525444557664399, 'worst fractal dimension': 0.14066370266791928, 'target': 0}\n"
     ]
    }
   ],
   "source": [
    "def transform_with_preprocessor(batch_df, preprocessor):\n",
    "    # The preprocessor doesn't include the `target` column,\n",
    "    # so remove it temporarily, then add it back.\n",
    "    target = batch_df.pop(\"target\")\n",
    "    transformed_features = preprocessor.transform_batch(batch_df)\n",
    "    transformed_features[\"target\"] = target\n",
    "    return transformed_features\n",
    "\n",
    "\n",
    "# Apply the transformation to each batch.\n",
    "test_dataset = test_dataset.map_batches(\n",
    "    transform_with_preprocessor,\n",
    "    fn_kwargs={\"preprocessor\": preprocessor},\n",
    "    batch_format=\"pandas\",\n",
    "    batch_size=1000,\n",
    ")\n",
    "\n",
    "test_dataset.show(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21e21449",
   "metadata": {},
   "source": [
    "### Load the trained model\n",
    "\n",
    "Now that you've defined the preprocessing pipeline, you're ready to run batch inference. Load the model from the artifact registry:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5859a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.train import Checkpoint\n",
    "from ray.train.xgboost import RayTrainReportCallback\n",
    "\n",
    "checkpoint = Checkpoint.from_directory(best_artifacts_dir)\n",
    "model = RayTrainReportCallback.get_model(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c41b340",
   "metadata": {},
   "source": [
    "### Run batch inference\n",
    "\n",
    "Next, run the inference step. To avoid repeatedly loading the model for each batch, define a reusable class that can use the same XGBoost model for different batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29ef57ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import xgboost\n",
    "\n",
    "from dist_xgboost.data import load_model_and_preprocessor\n",
    "\n",
    "\n",
    "class Validator:\n",
    "    def __init__(self):\n",
    "        _, self.model = load_model_and_preprocessor()\n",
    "\n",
    "    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:\n",
    "        # Remove the target column for inference.\n",
    "        target = batch.pop(\"target\")\n",
    "        dmatrix = xgboost.DMatrix(batch)\n",
    "        predictions = self.model.predict(dmatrix)\n",
    "\n",
    "        results = pd.DataFrame({\"prediction\": predictions, \"target\": target})\n",
    "        return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5be7083f",
   "metadata": {},
   "source": [
    "Now parallelize inference across replicas of the model by processing data in batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3690ceb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-16 21:14:56,496\tWARNING actor_pool_map_operator.py:287 -- To ensure full parallelization across an actor pool of size 4, the Dataset should consist of at least 4 distinct blocks. Consider increasing the parallelism when creating the Dataset.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'prediction': 0.031001044437289238, 'target': 0}\n"
     ]
    }
   ],
   "source": [
    "test_predictions = test_dataset.map_batches(\n",
    "    Validator,\n",
    "    concurrency=4,  # Number of model replicas.\n",
    "    batch_format=\"pandas\",\n",
    ")\n",
    "\n",
    "test_predictions.show(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a6dcc80",
   "metadata": {},
   "source": [
    "### Calculate evaluation metrics\n",
    "\n",
    "Now that you have predictions, evaluate the model's accuracy, precision, recall, and F1-score. Calculate the number of true positives, true negatives, false positives, and false negatives across the test subset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e2205b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "\n",
    "def confusion_matrix_batch(batch, threshold=0.5):\n",
    "    # Apply a threshold to get binary predictions.\n",
    "    batch[\"prediction\"] = (batch[\"prediction\"] > threshold).astype(int)\n",
    "\n",
    "    result = {}\n",
    "    cm = confusion_matrix(batch[\"target\"], batch[\"prediction\"], labels=[0, 1])\n",
    "    result[\"TN\"] = cm[0, 0]\n",
    "    result[\"FP\"] = cm[0, 1]\n",
    "    result[\"FN\"] = cm[1, 0]\n",
    "    result[\"TP\"] = cm[1, 1]\n",
    "    return pd.DataFrame(result, index=[0])\n",
    "\n",
    "\n",
    "test_results = test_predictions.map_batches(confusion_matrix_batch, batch_format=\"pandas\", batch_size=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09edcdb8",
   "metadata": {},
   "source": [
    "Finally, aggregate the confusion matrix results from all batches to get the global counts. This step materializes the dataset and executes all previously declared lazy transformations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a5dbbca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-04-16 21:15:01,144\tWARNING actor_pool_map_operator.py:287 -- To ensure full parallelization across an actor pool of size 4, the Dataset should consist of at least 4 distinct blocks. Consider increasing the parallelism when creating the Dataset.\n"
     ]
    }
   ],
   "source": [
    "# Sum all confusion matrix values across batches.\n",
    "cm_sums = test_results.sum([\"TN\", \"FP\", \"FN\", \"TP\"])\n",
    "\n",
    "# Extract confusion matrix components.\n",
    "tn = cm_sums[\"sum(TN)\"]\n",
    "fp = cm_sums[\"sum(FP)\"]\n",
    "fn = cm_sums[\"sum(FN)\"]\n",
    "tp = cm_sums[\"sum(TP)\"]\n",
    "\n",
    "# Calculate metrics.\n",
    "accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
    "precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n",
    "recall = tp / (tp + fn) if (tp + fn) > 0 else 0\n",
    "f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n",
    "\n",
    "metrics = {\"precision\": precision, \"recall\": recall, \"f1\": f1, \"accuracy\": accuracy}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae128ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation results:\n",
      "precision: 0.9464\n",
      "recall: 0.9815\n",
      "f1: 0.9636\n",
      "accuracy: 0.9535\n"
     ]
    }
   ],
   "source": [
    "print(\"Validation results:\")\n",
    "for key, value in metrics.items():\n",
    "    print(f\"{key}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16c8ec6b",
   "metadata": {},
   "source": [
    "The following is an example of the expected output. Exact values can vary between runs:\n",
    "\n",
    "```\n",
    "Validation results:\n",
    "precision: 0.9574\n",
    "recall: 1.0000\n",
    "f1: 0.9783\n",
    "accuracy: 0.9767\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd496163",
   "metadata": {},
   "source": [
    "## Observability\n",
    "\n",
    "Ray Data provides built-in observability features to help you monitor and debug data processing pipelines:\n",
    " \n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/ray_data_observability.png\" width=800>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c2340a",
   "metadata": {},
   "source": [
    "### Production deployment\n",
    "\n",
    "You can wrap the batch inference workload as a production-grade [Anyscale Job](https://docs.anyscale.com/platform/jobs/). See the [API reference](https://docs.anyscale.com/reference/job-api/):\n",
    "\n",
    "```bash\n",
    "# Production batch job.\n",
    "anyscale job submit --name=validate-xgboost-breast-cancer-model \\\n",
    "  --containerfile=\"${WORKING_DIR}/containerfile\" \\\n",
    "  --working-dir=\"${WORKING_DIR}\" \\\n",
    "  --exclude=\"\" \\\n",
    "  --max-retries=0 \\\n",
    "  -- python dist_xgboost/infer.py\n",
    "```\n",
    "\n",
    "For this command to succeed, first configure MLflow to store the artifacts in storage that's readable across clusters. Anyscale offers a variety of storage options that work out of the box, such as a [default storage bucket](https://docs.anyscale.com/configuration/storage/#anyscale-default-storage-bucket), as well as [automatically mounted network storage](https://docs.anyscale.com/configuration/storage/) shared at the cluster, user, and cloud levels. You can also set up your own network mounts or storage buckets."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa1095f1",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, you:\n",
    "\n",
    "1. Loaded a test dataset using distributed reads from cloud storage\n",
    "2. Transformed the dataset in a streaming fashion with the same preprocessor used during training\n",
    "3. Created a validation pipeline to:\n",
    "   - Make predictions on the test data using multiple model replicas\n",
    "   - Calculate confusion matrix components for each batch\n",
    "   - Aggregate results across all batches\n",
    "4. Computed key performance metrics, like precision, recall, F1-score, and accuracy\n",
    "\n",
    "The same code can efficiently run on terabyte-scale datasets without modifications using Ray Data's distributed processing capabilities.\n",
    "\n",
    "The next tutorial shows how to serve this XGBoost model for online inference using Ray Serve."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
